--- ../../plenoxels/svox2/svox2.py	2024-02-13 12:06:07.930166019 -0500
+++ ../google/plenoxels/svox2/svox2.py	2024-02-13 10:46:24.456645151 -0500
@@ -678,7 +678,7 @@
         invdirs = 1.0 / dirs
 
         gsz = self._grid_size()
-        gsz_cu = gsz.to(device=dirs.device)
+        gsz_cu = gsz.to(device=dirs.device).float()
         t1 = (-0.5 - origins) * invdirs
         t2 = (gsz_cu - 0.5 - origins) * invdirs
 
@@ -713,16 +713,16 @@
         while good_indices.numel() > 0:
             pos = origins + t[:, None] * dirs
             pos = pos.clamp_min_(0.0)
-            pos[:, 0] = torch.clamp_max(pos[:, 0], gsz[0] - 1)
-            pos[:, 1] = torch.clamp_max(pos[:, 1], gsz[1] - 1)
-            pos[:, 2] = torch.clamp_max(pos[:, 2], gsz[2] - 1)
+            pos[:, 0] = torch.clamp_max(pos[:, 0], gsz_cu[0] - 1)
+            pos[:, 1] = torch.clamp_max(pos[:, 1], gsz_cu[1] - 1)
+            pos[:, 2] = torch.clamp_max(pos[:, 2], gsz_cu[2] - 1)
             #  print('pym', pos, log_light_intensity)
 
             l = pos.to(torch.long)
             l.clamp_min_(0)
-            l[:, 0] = torch.clamp_max(l[:, 0], gsz[0] - 2)
-            l[:, 1] = torch.clamp_max(l[:, 1], gsz[1] - 2)
-            l[:, 2] = torch.clamp_max(l[:, 2], gsz[2] - 2)
+            l[:, 0] = torch.clamp_max(l[:, 0], gsz_cu[0].to(torch.long) - 2)
+            l[:, 1] = torch.clamp_max(l[:, 1], gsz_cu[1].to(torch.long) - 2)
+            l[:, 2] = torch.clamp_max(l[:, 2], gsz_cu[2].to(torch.long) - 2)
             pos -= l
 
             # BEGIN CRAZY TRILERP
@@ -942,15 +942,15 @@
         while good_indices.numel() > 0:
             pos = origins + t[:, None] * dirs
             pos = pos.clamp_min_(0.0)
-            pos[:, 0] = torch.clamp_max(pos[:, 0], gsz[0] - 1)
-            pos[:, 1] = torch.clamp_max(pos[:, 1], gsz[1] - 1)
-            pos[:, 2] = torch.clamp_max(pos[:, 2], gsz[2] - 1)
+            pos[:, 0] = torch.clamp_max(pos[:, 0], gsz_cu[0] - 1)
+            pos[:, 1] = torch.clamp_max(pos[:, 1], gsz_cu[1] - 1)
+            pos[:, 2] = torch.clamp_max(pos[:, 2], gsz_cu[2] - 1)
 
             l = pos.to(torch.long)
             l.clamp_min_(0)
-            l[:, 0] = torch.clamp_max(l[:, 0], gsz[0] - 2)
-            l[:, 1] = torch.clamp_max(l[:, 1], gsz[1] - 2)
-            l[:, 2] = torch.clamp_max(l[:, 2], gsz[2] - 2)
+            l[:, 0] = torch.clamp_max(l[:, 0], gsz_cu[0].to(torch.long) - 2)  # integer bound: l is a long index tensor
+            l[:, 1] = torch.clamp_max(l[:, 1], gsz_cu[1].to(torch.long) - 2)
+            l[:, 2] = torch.clamp_max(l[:, 2], gsz_cu[2].to(torch.long) - 2)
             pos -= l
 
             # BEGIN CRAZY TRILERP
@@ -1144,7 +1144,8 @@
     def volume_render_image(
         self, camera: Camera, use_kernel: bool = True, randomize: bool = False,
         batch_size : int = 5000,
-        return_raylen: bool=False
+        return_raylen: bool=False,
+        debug: bool=False
     ):
         """
         Standard volume rendering (entire image version).
@@ -1156,7 +1157,7 @@
         :return: (H, W, 3), predicted RGB image
         """
         imrend_fn_name = f"volume_render_{self.opt.backend}_image"
-        if self.basis_type != BASIS_TYPE_MLP and imrend_fn_name in _C.__dict__ and not torch.is_grad_enabled() and not return_raylen:
+        if self.basis_type != BASIS_TYPE_MLP and imrend_fn_name in _C.__dict__ and not torch.is_grad_enabled() and not return_raylen and not debug:
             # Use the fast image render kernel if available
             cu_fn = _C.__dict__[imrend_fn_name]
             return cu_fn(
@@ -1168,6 +1169,7 @@
             # Manually generate rays for now
             rays = camera.gen_rays()
             all_rgb_out = []
+            #print("Rendering image...")
             for batch_start in range(0, camera.height * camera.width, batch_size):
                 rgb_out_part = self.volume_render(rays[batch_start:batch_start+batch_size],
                                                   use_kernel=use_kernel,
